"""
Phase 3: Downgrade Prediction — Logit, Cloglog, Ordered Logit
================================================================
3a. Any-downgrade logit (~20 events) — sequential specs
3b. Loss-of-safe cloglog (~7 events) — rare events
3c. Ordered logit for rating categories (all variation)

Adapted from crises/scripts/phase8_reviewer_response.py

Output: table4_downgrade_logit.md, table5_safe_loss_cloglog.md,
        table6_ordered_logit.md, phase3_results.csv
"""

import sys
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.optimize import minimize
from scipy import stats as sp_stats
import warnings
warnings.filterwarnings('ignore')

# Hard-coded absolute location of this project; every data/output directory
# below is derived from it, so the script only runs where this path exists.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/safe_asset_cliff")
# Parent directory of the project (used only for the sys.path entry below).
ROOT_DIR = PROJECT_DIR.parent
# Input panel is read from here (see main()).
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
# Markdown tables and the results CSV are written here.
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# Put <ROOT_DIR>/multilateral/src at the front of the import path.
# NOTE(review): nothing in this file visibly imports from that directory —
# confirm whether this entry is still needed.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))


def stars(p):
    """Return the significance marker for p-value *p*.

    '***' when p < 0.01, '**' when p < 0.05, '*' when p < 0.1, and an empty
    string otherwise (including when *p* is NaN, since every comparison is
    then false).
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format a coefficient and its standard error as two table cells.

    Returns a 2-tuple of strings: *val* to four decimals followed by the
    significance stars for *p*, and *se* to four decimals in parentheses.
    """
    coef_cell = f"{val:.4f}" + stars(p)
    se_cell = "(" + f"{se:.4f}" + ")"
    return coef_cell, se_cell


def compute_auc(y_true, y_scores):
    """Wilcoxon-Mann-Whitney AUC with proper handling of tied scores.

    The AUC is the probability that a randomly chosen positive (``y == 1``)
    receives a higher score than a randomly chosen negative, with ties
    counted as one half. It is computed from the rank-sum (Mann-Whitney U)
    statistic using average ranks, so the result does not depend on the
    order in which tied observations happen to be sorted.

    Args:
        y_true: 1-D array-like of binary outcomes (1 = event).
        y_scores: 1-D array-like of predicted scores, same length.

    Returns:
        The AUC as a float in [0, 1], or ``np.nan`` when either class is
        absent (the statistic is undefined in that case).
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores, dtype=float)

    is_pos = y_true == 1
    n_pos = int(is_pos.sum())
    n_neg = y_true.size - n_pos
    if n_pos == 0 or n_neg == 0:
        return np.nan

    # Average ranks give tied scores the same rank, which is exactly the
    # "half credit for ties" convention of the Mann-Whitney statistic.
    ranks = sp_stats.rankdata(y_scores)
    u_stat = ranks[is_pos].sum() - n_pos * (n_pos + 1) / 2
    return u_stat / (n_pos * n_neg)


def run_logit(df, y_var, x_vars, label, min_events=5):
    """Fit a pooled binary logit by maximum likelihood and summarize the fit.

    Rows with a missing value in ``y_var``, any of ``x_vars`` or ``iso3`` are
    dropped. The model (intercept + ``x_vars``) is estimated with BFGS on
    standardized regressors and the coefficients are mapped back to the
    original scale. Standard errors come from the inverse of ``X'WX`` with
    ``W = p(1 - p)`` at the fitted probabilities (no clustering).

    Args:
        df: DataFrame containing ``y_var``, ``x_vars`` and an ``iso3`` column.
        y_var: Name of the 0/1 dependent variable.
        x_vars: List of regressor column names.
        label: Specification label used in printed output and in the result.
        min_events: Minimum number of events AND of non-events required.

    Returns:
        ``None`` if there are fewer than 30 usable rows, too few events or
        non-events, or the estimation raises. Otherwise a dict with
        ``model``, ``estimator``, ``n_obs``, ``n_countries``, ``n_events``,
        ``pseudo_r2`` (1 - ll_model / ll_null), ``auc``, ``brier``, the
        fitted probabilities ``p_hat`` and outcomes ``y``, plus
        ``<name>_beta``, ``<name>_se``, ``<name>_p`` and ``<name>_mfx``
        (marginal effect at the regressor means) for every regressor.
    """
    cols = [y_var] + x_vars + ['iso3']
    sub = df[cols].dropna()
    if len(sub) < 30:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    y = sub[y_var].values.astype(float)
    X = np.column_stack([np.ones(len(sub)), sub[x_vars].values.astype(float)])
    n, k = X.shape

    if y.sum() < min_events or (1 - y).sum() < min_events:
        print(f"  {label}: insufficient events ({y.sum():.0f}), skipping")
        return None

    # Standardize for numerical stability (constant columns keep scale 1)
    x_means = X[:, 1:].mean(axis=0)
    x_stds = X[:, 1:].std(axis=0)
    x_stds[x_stds == 0] = 1
    X_std = X.copy()
    X_std[:, 1:] = (X[:, 1:] - x_means) / x_stds

    def neg_ll(beta):
        z = X_std @ beta
        z = np.clip(z, -30, 30)
        p = 1 / (1 + np.exp(-z))
        p = np.clip(p, 1e-12, 1 - 1e-12)
        return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

    def grad(beta):
        z = X_std @ beta
        z = np.clip(z, -30, 30)
        p = 1 / (1 + np.exp(-z))
        return -X_std.T @ (y - p)

    try:
        result = minimize(neg_ll, np.zeros(k), jac=grad,
                          method='BFGS', options={'maxiter': 1000, 'gtol': 1e-6})
        if not result.success:
            # The estimates are still reported (BFGS often stops on precision
            # loss close to the optimum), but the reader should know.
            print(f"  {label}: WARNING optimizer did not report convergence "
                  f"({result.message})")
        beta_std = result.x

        # Transform back to the original regressor scale
        beta = np.zeros(k)
        beta[1:] = beta_std[1:] / x_stds
        beta[0] = beta_std[0] - np.sum(beta_std[1:] * x_means / x_stds)

        z = X @ beta
        z = np.clip(z, -30, 30)
        p_hat = 1 / (1 + np.exp(-z))
        p_hat = np.clip(p_hat, 1e-12, 1 - 1e-12)

        # Standard errors from the information matrix X'WX
        W = p_hat * (1 - p_hat)
        H = X.T @ (X * W[:, None])
        try:
            V = np.linalg.inv(H)
            se = np.sqrt(np.diag(V))
        except np.linalg.LinAlgError:
            se = np.full(k, np.nan)

        t_stats = beta / se
        # sf() keeps precision in the far tail, where 1 - cdf() rounds to 0.
        pvalues = 2 * sp_stats.norm.sf(np.abs(t_stats))

        # Pseudo-R²
        ll_model = -result.fun
        p_bar = y.mean()
        ll_null = n * (p_bar * np.log(p_bar + 1e-12) + (1 - p_bar) * np.log(1 - p_bar + 1e-12))
        pseudo_r2 = 1 - ll_model / ll_null if ll_null != 0 else 0

        # MFX at mean
        z_mean = X.mean(axis=0) @ beta
        p_mean = 1 / (1 + np.exp(-z_mean))
        mfx = beta[1:] * p_mean * (1 - p_mean)

        # Scoring
        auc = compute_auc(y, p_hat)
        brier = np.mean((y - p_hat) ** 2)

    except Exception as e:
        # Deliberate best-effort: one failed specification must not abort
        # the remaining ones.
        print(f"  {label}: logit failed ({e}), skipping")
        return None

    res = {
        'model': label, 'estimator': 'logit',
        'n_obs': n, 'n_countries': sub['iso3'].nunique(),
        'n_events': int(y.sum()),
        'pseudo_r2': pseudo_r2, 'auc': auc, 'brier': brier,
        'p_hat': p_hat, 'y': y,
    }

    print(f"\n  {label} (N={n}, events={int(y.sum())}, "
          f"Pseudo-R²={pseudo_r2:.4f}, AUC={auc:.3f})")
    for i, name in enumerate(x_vars):
        # Index 0 of beta/se/pvalues is the intercept, hence the i + 1 offset.
        sig = stars(pvalues[i + 1])
        print(f"    {name:30s} β={beta[i+1]:8.4f} (se={se[i+1]:.4f}) {sig}  "
              f"[MFX={mfx[i]:.5f}]")
        res[f'{name}_beta'] = beta[i + 1]
        res[f'{name}_se'] = se[i + 1]
        res[f'{name}_p'] = pvalues[i + 1]
        res[f'{name}_mfx'] = mfx[i]

    return res


def run_cloglog(df, y_var, x_vars, label, min_events=3):
    """Fit a complementary log-log binary model for rare events.

    The event probability is ``p = 1 - exp(-exp(X beta))``. Rows with a
    missing value in ``y_var``, any of ``x_vars`` or ``iso3`` are dropped.
    Estimation uses BFGS on standardized regressors; coefficients are mapped
    back to the original scale. Standard errors come from the inverse of
    ``X'WX`` with ``W = (dp/deta)^2 / (p (1 - p))`` (no clustering).

    Args:
        df: DataFrame containing ``y_var``, ``x_vars`` and an ``iso3`` column.
        y_var: Name of the 0/1 dependent variable.
        x_vars: List of regressor column names.
        label: Specification label used in printed output and in the result.
        min_events: Minimum number of events required (non-events are not
            checked here).

    Returns:
        ``None`` if there are fewer than 30 usable rows, too few events, or
        the estimation raises. Otherwise a dict with ``model``,
        ``estimator``, ``n_obs``, ``n_countries``, ``n_events``,
        ``pseudo_r2``, ``auc``, ``brier`` and, per regressor,
        ``<name>_beta``, ``<name>_se``, ``<name>_p`` and ``<name>_mfx``
        (marginal effect at the regressor means).
    """
    cols = [y_var] + x_vars + ['iso3']
    sub = df[cols].dropna()
    if len(sub) < 30:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    y = sub[y_var].values.astype(float)
    X = np.column_stack([np.ones(len(sub)), sub[x_vars].values.astype(float)])
    n, k = X.shape

    if y.sum() < min_events:
        print(f"  {label}: insufficient events ({y.sum():.0f}), skipping")
        return None

    # Standardize for numerical stability (constant columns keep scale 1)
    x_means = X[:, 1:].mean(axis=0)
    x_stds = X[:, 1:].std(axis=0)
    x_stds[x_stds == 0] = 1
    X_std = X.copy()
    X_std[:, 1:] = (X[:, 1:] - x_means) / x_stds

    def neg_ll(beta):
        eta = X_std @ beta
        eta = np.clip(eta, -30, 30)
        p = 1 - np.exp(-np.exp(eta))
        p = np.clip(p, 1e-12, 1 - 1e-12)
        return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

    def grad(beta):
        eta = X_std @ beta
        eta = np.clip(eta, -30, 30)
        exp_eta = np.exp(eta)
        p = 1 - np.exp(-exp_eta)
        p = np.clip(p, 1e-12, 1 - 1e-12)
        # dp/deta = exp(eta) * exp(-exp(eta))
        dp = np.exp(eta - exp_eta)
        dp = np.clip(dp, 1e-12, 1e12)
        w = (y / p - (1 - y) / (1 - p)) * dp
        return -X_std.T @ w

    try:
        result = minimize(neg_ll, np.zeros(k), jac=grad,
                          method='BFGS', options={'maxiter': 2000, 'gtol': 1e-5})
        if not result.success:
            # The estimates are still reported (BFGS often stops on precision
            # loss close to the optimum), but the reader should know.
            print(f"  {label}: WARNING optimizer did not report convergence "
                  f"({result.message})")
        beta_std = result.x

        # Transform back to the original regressor scale
        beta = np.zeros(k)
        beta[1:] = beta_std[1:] / x_stds
        beta[0] = beta_std[0] - np.sum(beta_std[1:] * x_means / x_stds)

        eta = X @ beta
        eta = np.clip(eta, -30, 30)
        p_hat = 1 - np.exp(-np.exp(eta))
        p_hat = np.clip(p_hat, 1e-12, 1 - 1e-12)

        # Standard errors from the information matrix X'WX
        exp_eta = np.exp(eta)
        dp = np.exp(eta - exp_eta)
        dp = np.clip(dp, 1e-12, 1e12)
        W = dp**2 / (p_hat * (1 - p_hat))
        W = np.clip(W, 1e-12, 1e12)
        H = X.T @ (X * W[:, None])
        try:
            V = np.linalg.inv(H)
            se = np.sqrt(np.abs(np.diag(V)))
        except np.linalg.LinAlgError:
            se = np.full(k, np.nan)

        t_stats = beta / se
        # sf() keeps precision in the far tail, where 1 - cdf() rounds to 0.
        pvalues = 2 * sp_stats.norm.sf(np.abs(t_stats))

        # Pseudo-R²
        ll_model = -result.fun
        p_bar = y.mean()
        ll_null = n * (p_bar * np.log(p_bar + 1e-12) + (1 - p_bar) * np.log(1 - p_bar + 1e-12))
        pseudo_r2 = 1 - ll_model / ll_null if ll_null != 0 else 0

        # MFX at mean: beta * dp/deta evaluated at the regressor means
        eta_mean = X.mean(axis=0) @ beta
        exp_eta_m = np.exp(eta_mean)
        dp_mean = np.exp(eta_mean - exp_eta_m)
        mfx = beta[1:] * dp_mean

        # Scoring
        auc = compute_auc(y, p_hat)
        brier = np.mean((y - p_hat) ** 2)

    except Exception as e:
        # Deliberate best-effort: one failed specification must not abort
        # the remaining ones.
        print(f"  {label}: cloglog failed ({e}), skipping")
        return None

    res = {
        'model': label, 'estimator': 'cloglog',
        'n_obs': n, 'n_countries': sub['iso3'].nunique(),
        'n_events': int(y.sum()),
        'pseudo_r2': pseudo_r2, 'auc': auc, 'brier': brier,
    }

    print(f"\n  {label} (N={n}, events={int(y.sum())}, "
          f"Pseudo-R²={pseudo_r2:.4f}, AUC={auc:.3f}) [Cloglog]")
    for i, name in enumerate(x_vars):
        # Index 0 of beta/se/pvalues is the intercept, hence the i + 1 offset.
        sig = stars(pvalues[i + 1])
        print(f"    {name:30s} β={beta[i+1]:8.4f} (se={se[i+1]:.4f}) {sig}  "
              f"[MFX={mfx[i]:.5f}]")
        res[f'{name}_beta'] = beta[i + 1]
        res[f'{name}_se'] = se[i + 1]
        res[f'{name}_p'] = pvalues[i + 1]
        res[f'{name}_mfx'] = mfx[i]

    return res


def run_ordered_logit(df, y_var, x_vars, label):
    """Fit a custom ordered logit for an ordinal outcome.

    The model is ``P(Y <= j) = 1 / (1 + exp(-(alpha_j - X beta)))`` with one
    cutpoint ``alpha_j`` per category boundary and no intercept, so a
    positive ``beta`` shifts probability mass toward higher categories.
    Rows with a missing value in ``y_var``, any of ``x_vars`` or ``iso3`` are
    dropped. Estimation uses BFGS (numerical gradient) on standardized
    regressors; standard errors come from a central finite-difference
    Hessian of the negative log-likelihood and are rescaled to the original
    regressor units. Cutpoints are reported on the standardized scale.

    Args:
        df: DataFrame containing ``y_var``, ``x_vars`` and an ``iso3`` column.
        y_var: Name of the ordinal dependent variable (cast to int).
        x_vars: List of regressor column names.
        label: Specification label used in printed output and in the result.

    Returns:
        ``None`` if there are fewer than 50 usable rows, fewer than three
        observed categories, or the estimation raises. Otherwise a dict with
        ``model``, ``estimator``, ``n_obs``, ``n_countries``,
        ``n_categories``, ``pseudo_r2``, ``categories`` (observed category
        values as strings) and, per regressor, ``<name>_beta``,
        ``<name>_se`` and ``<name>_p``.
    """
    cols = [y_var] + x_vars + ['iso3']
    sub = df[cols].dropna()
    if len(sub) < 50:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    y = sub[y_var].values.astype(int)
    X_raw = sub[x_vars].values.astype(float)
    n, p = X_raw.shape

    categories = sorted(np.unique(y))
    K = len(categories)
    if K < 3:
        print(f"  {label}: too few categories ({K}), skipping")
        return None

    # Map y to 0..K-1 (every y is present in the sorted category list, so
    # searchsorted returns its exact position).
    y_mapped = np.searchsorted(categories, y)
    rows = np.arange(n)

    # Standardize X (constant columns keep scale 1)
    x_means = X_raw.mean(axis=0)
    x_stds = X_raw.std(axis=0)
    x_stds[x_stds == 0] = 1
    X_std = (X_raw - x_means) / x_stds

    # Parameters: beta (p,) + cutpoints (K-1,)
    n_params = p + (K - 1)

    def neg_ll(params):
        beta = params[:p]
        alphas = params[p:]
        xb = X_std @ beta
        # Cumulative probabilities P(Y <= j) for j = 0..K-2, one column each.
        eta = np.clip(alphas[None, :] - xb[:, None], -30, 30)
        cdf = 1 / (1 + np.exp(-eta))
        # Pad with P(Y <= -1) = 0 and P(Y <= K-1) = 1 so that every category
        # probability is a difference of two adjacent columns.
        cdf = np.column_stack([np.zeros(n), cdf, np.ones(n)])
        prob = cdf[rows, y_mapped + 1] - cdf[rows, y_mapped]
        # Cutpoints are unconstrained, so a difference can go non-positive.
        prob = np.maximum(prob, 1e-12)
        return -np.sum(np.log(prob))

    # Initial: evenly spaced cutpoints
    init_alphas = np.linspace(-2, 2, K - 1)
    init_params = np.concatenate([np.zeros(p), init_alphas])

    try:
        result = minimize(neg_ll, init_params, method='BFGS',
                          options={'maxiter': 2000, 'gtol': 1e-5})
        if not result.success:
            # The estimates are still reported (BFGS often stops on precision
            # loss close to the optimum), but the reader should know.
            print(f"  {label}: WARNING optimizer did not report convergence "
                  f"({result.message})")
        params = result.x
        beta_std = params[:p]
        alphas = params[p:]

        # Transform beta back to original scale
        beta = beta_std / x_stds

        # Numerical Hessian for SEs (central differences, upper triangle
        # computed and mirrored)
        eps = 1e-5
        H = np.zeros((n_params, n_params))
        for i_p in range(n_params):
            e_i = np.zeros(n_params)
            e_i[i_p] = eps
            for j_p in range(i_p, n_params):
                e_j = np.zeros(n_params)
                e_j[j_p] = eps
                f_pp = neg_ll(params + e_i + e_j)
                f_pm = neg_ll(params + e_i - e_j)
                f_mp = neg_ll(params - e_i + e_j)
                f_mm = neg_ll(params - e_i - e_j)
                H[i_p, j_p] = (f_pp - f_pm - f_mp + f_mm) / (4 * eps ** 2)
                H[j_p, i_p] = H[i_p, j_p]

        try:
            V = np.linalg.inv(H)
            se_all = np.sqrt(np.abs(np.diag(V)))
        except np.linalg.LinAlgError:
            se_all = np.full(n_params, np.nan)

        # SEs for original-scale beta
        se_beta = se_all[:p] / x_stds

        # sf() keeps precision in the far tail, where 1 - cdf() rounds to 0.
        pvalues = 2 * sp_stats.norm.sf(np.abs(beta / se_beta))

        # Pseudo-R² against the model that predicts the sample frequencies
        ll_model = -result.fun
        freq = np.bincount(y_mapped, minlength=K) / n
        freq = np.clip(freq, 1e-12, 1)
        ll_null = np.sum(np.log(freq[y_mapped]))
        pseudo_r2 = 1 - ll_model / ll_null if ll_null != 0 else 0

    except Exception as e:
        # Deliberate best-effort: one failed specification must not abort
        # the remaining ones.
        print(f"  {label}: ordered logit failed ({e}), skipping")
        return None

    res = {
        'model': label, 'estimator': 'ordered_logit',
        'n_obs': n, 'n_countries': sub['iso3'].nunique(),
        'n_categories': K,
        'pseudo_r2': pseudo_r2,
        'categories': [f"{c}" for c in categories],
    }

    print(f"\n  {label} (N={n}, K={K}, Pseudo-R²={pseudo_r2:.4f}) [Ordered Logit]")
    print(f"    Cutpoints: {', '.join(f'{a:.3f}' for a in alphas)}")
    for i, name in enumerate(x_vars):
        sig = stars(pvalues[i])
        print(f"    {name:30s} β={beta[i]:8.4f} (se={se_beta[i]:.4f}) {sig}")
        res[f'{name}_beta'] = beta[i]
        res[f'{name}_se'] = se_beta[i]
        res[f'{name}_p'] = pvalues[i]

    return res


def write_markdown_table(path, title, headers, rows, notes=None):
    """Write a titled markdown table to *path* (UTF-8).

    The file contains a ``###`` heading, a pipe table whose first column is
    left-aligned and all others right-aligned, and, if given, *notes* as an
    italic line below the table. The parent directory is created when it
    does not exist yet, so a fresh checkout without an output folder works.

    Args:
        path: Destination file (``pathlib.Path`` or string).
        title: Heading text.
        headers: Column header strings.
        rows: Iterable of rows; each cell is converted with ``str()``.
        notes: Optional footnote text.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    lines = [f"### {title}", ""]
    lines.append("| " + " | ".join(headers) + " |")
    lines.append("|" + "|".join(["--:" if i > 0 else ":--" for i in range(len(headers))]) + "|")
    for row in rows:
        lines.append("| " + " | ".join(str(c) for c in row) + " |")
    if notes:
        lines.append("")
        lines.append(f"*{notes}*")
    # Trailing empty entry makes the file end with a newline.
    lines.append("")
    path.write_text("\n".join(lines), encoding="utf-8")
    print(f"  Saved: {path}")


def main():
    """Run Phase 3 end to end and return the list of model result dicts.

    Reads ``cliff_panel.csv`` from ``PROCESSED_DIR``, then estimates
    (3a) any-downgrade logits, (3b) loss-of-safe cloglogs on the sample whose
    lagged rating is at least 18, and (3c) ordered logits for
    ``rating_category``. One markdown table per section and a combined
    ``phase3_results.csv`` are written to ``TABLES_DIR`` (created if
    missing). Regressors absent from the data are dropped from a
    specification with a printed notice; a specification left with no
    regressors is skipped.

    Returns:
        List of the result dicts returned by the estimators, in run order.
    """

    def available(frame, xvars, label):
        """Return the regressors of *xvars* present in *frame*, reporting the rest."""
        missing = [v for v in xvars if v not in frame.columns]
        if missing:
            print(f"  {label}: columns not in data, dropped: {', '.join(missing)}")
        return [v for v in xvars if v in frame.columns]

    def key_coef(res, candidates):
        """Format the first candidate coefficient found in *res*, or '-'."""
        for name in candidates:
            if f'{name}_beta' in res:
                sig = stars(res[f'{name}_p'])
                return f"{res[f'{name}_beta']:.3f}{sig}"
        return "-"

    print("=" * 70)
    print("PHASE 3: Downgrade Prediction — Logit / Cloglog / Ordered Logit")
    print("=" * 70)

    # Tables and the results CSV are written here; make sure it exists.
    TABLES_DIR.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(PROCESSED_DIR / "cliff_panel.csv")
    print(f"Loaded: {df['iso3'].nunique()} countries, {len(df):,} obs")
    print(f"  downgrade_any events: {df['downgrade_any'].sum()}")
    print(f"  lost_safe events: {df['lost_safe'].sum()}")

    all_results = []

    # ====================================================================
    # 3a. Any-downgrade logit
    # ====================================================================
    print("\n" + "=" * 70)
    print("3a. ANY-DOWNGRADE LOGIT")
    print("=" * 70)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    fiscal_vars = ['govt_debt_gdp', 'exp_rev_gap', 'primary_bal_gdp', 'rgdp_growth']
    fiscal_vars = [v for v in fiscal_vars if v in df.columns]

    # Use lagged fiscal variables for prediction, falling back to the
    # contemporaneous column when no "<name>_lag" column exists.
    fiscal_lag = [f'{v}_lag' if f'{v}_lag' in df.columns else v for v in fiscal_vars]

    specs = [
        ("3a.1: Demographics only", demo_vars),
        ("3a.2: + OADR level", ['old_dep']),
        ("3a.3: + OADR spline(20%)", ['old_dep', 'oadr_spline_20']),
        ("3a.4: + Fiscal", ['old_dep'] + fiscal_lag),
    ]

    # Add r-g and inflation if available
    full_vars = ['old_dep'] + fiscal_lag
    if 'r_minus_g' in df.columns:
        full_vars.append('r_minus_g')
    if 'inflation' in df.columns:
        full_vars.append('inflation')
    specs.append(("3a.5: Full", full_vars))

    # Predetermined
    if 'oadr_plus10' in df.columns:
        specs.append(("3a.6: Predetermined (OADR+10)", ['oadr_plus10'] + fiscal_lag))

    logit_table_rows = []
    for label, xvars in specs:
        xvars_avail = available(df, xvars, label)
        if not xvars_avail:
            print(f"  {label}: no regressors available, skipping")
            continue
        res = run_logit(df, 'downgrade_any', xvars_avail, label, min_events=5)
        if res:
            all_results.append(res)
            logit_table_rows.append([
                label, str(res['n_obs']), str(res['n_events']),
                f"{res['pseudo_r2']:.3f}", f"{res['auc']:.3f}",
                key_coef(res, ['old_dep', 'Z_1']),
            ])

    if logit_table_rows:
        write_markdown_table(
            TABLES_DIR / "table4_downgrade_logit.md",
            "Table 4: Downgrade Prediction — Logit Models",
            ["Specification", "N", "Events", "Pseudo-R²", "AUC", "Key Coef"],
            logit_table_rows,
            notes="Dependent variable: downgrade_any (=1 if rating decreased from prior year). "
                  "Key coef: OADR or Z₁ coefficient with significance stars."
        )

    # ====================================================================
    # 3b. Loss-of-safe cloglog
    # ====================================================================
    print("\n" + "=" * 70)
    print("3b. LOSS-OF-SAFE CLOGLOG")
    print("=" * 70)

    # Restrict to observations where lagged rating >= AA-
    safe_risk = df[df['rating_lag'] >= 18].copy()
    n_safe_events = int(safe_risk['lost_safe'].sum())
    print(f"  At-risk sample (lagged rating >= AA-): {len(safe_risk):,} obs, "
          f"lost_safe events: {safe_risk['lost_safe'].sum()}")

    cloglog_table_rows = []
    for label, xvars in [
        ("3b.1: OADR", ['old_dep']),
        ("3b.2: OADR + fiscal", ['old_dep', 'govt_debt_gdp_lag', 'exp_rev_gap_lag']),
        ("3b.3: Full", ['old_dep', 'govt_debt_gdp_lag', 'exp_rev_gap_lag', 'rgdp_growth']),
    ]:
        xvars_avail = available(safe_risk, xvars, label)
        if not xvars_avail:
            print(f"  {label}: no regressors available, skipping")
            continue
        res = run_cloglog(safe_risk, 'lost_safe', xvars_avail, label, min_events=3)
        if res:
            all_results.append(res)
            cloglog_table_rows.append([
                label, str(res['n_obs']), str(res['n_events']),
                f"{res['pseudo_r2']:.3f}", f"{res['auc']:.3f}",
                key_coef(res, ['old_dep']),
            ])

    if cloglog_table_rows:
        write_markdown_table(
            TABLES_DIR / "table5_safe_loss_cloglog.md",
            "Table 5: Loss of Safe Status — Cloglog Models",
            ["Specification", "N", "Events", "Pseudo-R²", "AUC", "OADR Coef"],
            cloglog_table_rows,
            # Event count is taken from the at-risk sample rather than
            # hard-coded, so the caveat stays true when the panel changes.
            notes="Restricted to country-years with lagged rating >= AA-. "
                  "Complementary log-log link function (rare events). "
                  f"Low-power caveat: only {n_safe_events} events."
        )

    # ====================================================================
    # 3c. Ordered logit for rating categories
    # ====================================================================
    print("\n" + "=" * 70)
    print("3c. ORDERED LOGIT FOR RATING CATEGORIES")
    print("=" * 70)

    ordered_table_rows = []
    for label, xvars in [
        ("3c.1: Demographics Z", demo_vars),
        ("3c.2: OADR", ['old_dep']),
        ("3c.3: OADR + fiscal", ['old_dep', 'govt_debt_gdp', 'exp_rev_gap']),
        ("3c.4: Full", ['old_dep', 'govt_debt_gdp', 'exp_rev_gap',
                         'rgdp_growth', 'inflation']),
    ]:
        xvars_avail = available(df, xvars, label)
        if not xvars_avail:
            print(f"  {label}: no regressors available, skipping")
            continue
        res = run_ordered_logit(df, 'rating_category', xvars_avail, label)
        if res:
            all_results.append(res)
            ordered_table_rows.append([
                label, str(res['n_obs']),
                str(res['n_categories']), f"{res['pseudo_r2']:.3f}",
                key_coef(res, ['old_dep', 'Z_1']),
            ])

    if ordered_table_rows:
        write_markdown_table(
            TABLES_DIR / "table6_ordered_logit.md",
            "Table 6: Rating Category — Ordered Logit",
            ["Specification", "N", "Categories", "Pseudo-R²", "Key Coef"],
            ordered_table_rows,
            notes="Categories: AAA (4), AA+ (3), AA (2), AA- (1), Below AA- (0). "
                  "Positive β = higher probability of higher category."
        )

    # ====================================================================
    # Save results
    # ====================================================================
    if all_results:
        # Array/list-valued entries are not meaningful as CSV cells.
        results_df = pd.DataFrame(all_results).drop(
            columns=['p_hat', 'y', 'categories'], errors='ignore')
        results_df.to_csv(TABLES_DIR / "phase3_results.csv", index=False)
        print(f"\n  Saved: {TABLES_DIR / 'phase3_results.csv'}")

    print("\n" + "=" * 70)
    print("Phase 3 complete.")
    print("=" * 70)

    return all_results


# Script entry point: run the whole Phase 3 pipeline when executed directly.
# The returned result list is bound to a module-level name (presumably so it
# can be inspected in an interactive session after the run).
if __name__ == "__main__":
    results = main()
